#!/usr/bin/env python

"""
This is a helper program designed for use with the NASA GeneLab 
metagenomics assembly-based workflow. It takes as input multiple individual-
sample annotation and taxonomic coverage tables and creates 4 output merged tables:
    1) summed based on KO annotations
    2) summed based on KO annotations and normalized with the specified method (set with -n flag)
    3) summed based on taxonomic classification
    4) summed based on taxonomic classification and normalized with the specified method (set with -n flag)

These tables are produced during the GeneLab Illumina metagenomics processing workflow. The location of this protocol 
document will be added here as soon as it is publicly available.

### details on normalization methods
Normalization for sampling depth is done by either coverage per million (CPM) or using the median-ratio (MR)
method as performed in DESeq2. 

I initially wrote this for normalizing metagenomic coverage data, like gene-level coverage, or summed KO coverages. 
These are normalized for gene-length already because they are "coverages", but they are not yet normalized 
for sampling depth – which is where this script comes in. 

I also found myself wanting this because I wanted to do differential abundance testing of coverages
of KO terms. DESeq2 doesn't require normalizing for gene-length because it is the same unit being analyzed
across all samples – the same gene, so the same size. However, after grouping genes into their KO annotations,
(which we may need to compare across samples that don't all share the same underlying assembly or genes), or by
taxonomy, they no longer all represent the same units across all samples. It is because of this I decided to 
stick with gene-level coverages (which are normalized for gene-length), and then sum those values based on KO 
annotations or taxonomy.

The CPM (coverage per million) normalization is just like a percent, except scaled to 1 million instead of 100.
So each row's entry (e.g. gene/KO/taxon/etc.) is the proportion out of 1 million for that column (sample), 
and each column will sum to 1 million.

The median-ratio normalization method (MR) was initially described in this paper 
(http://dx.doi.org/10.1186/gb-2010-11-10-r106; eq. 5), and this site is super-informative in general 
about the DESeq2 process overall, and helped me understand the normalization process better to implement it: 
https://hbctraining.github.io/DGE_workshop/lessons/02_DGE_count_normalization.html. Columns will not sum to 
the same amount when the median-ratio method is applied.
"""

import os
import sys
import argparse
import textwrap
import pandas as pd
from math import isnan
from numpy import NaN
import numpy as np
from scipy.stats.mstats import gmean

# Command-line interface. This is built and parsed at module level, so `args` is a
# module-level global that main() reads.
parser = argparse.ArgumentParser(
    description="This is a helper program designed for use with the NASA GeneLab "
                "metagenomics assembly-based workflow. See note at top of script for more info. "
                "For version info, run `bit-version`.")

required = parser.add_argument_group('required arguments')

required.add_argument("input_tables", metavar="input-tables", type=str, nargs="+",
                      help='Input coverage, annotation, and tax tables (as written, '
                           'expected to end with extension ".tsv")')

# the help text lists only the values that `choices` actually accepts
# (it previously also advertised "none", which argparse would reject)
parser.add_argument("-n", "--normalization-method",
                    help='Desired normalization method of either "CPM" as in coverage per million, '
                         'or "MR" as in median-ratio as performed in DESeq2. '
                         'See note at top of program for more info. (default: "CPM")',
                    choices=["CPM", "MR"], action="store", default="CPM")

parser.add_argument("-o", "--output-prefix", help='Desired output prefix (default: "Combined")',
                    action="store", default="Combined", dest="output_prefix")

# with no arguments at all, show the help menu and exit cleanly instead of erroring
if len(sys.argv) == 1:
    parser.print_help(sys.stderr)
    sys.exit(0)

args = parser.parse_args()

################################################################################


def main():
    """
    Runs the whole program: checks the inputs, collapses each input table by KO and by taxid,
    merges the per-sample tables, normalizes them, and writes the 4 output tables.
    Reads its settings from the module-level `args`.
    """

    check_all_inputs_exist(args.input_tables)

    input_files, sample_names = setup_input_lists(args.input_tables)

    # starting with empty KO and tax mapping dictionaries, and an empty list of all-NA taxids
    KO_collapsed_tabs, tax_collapsed_tabs, KO_dict, tax_dict = process_each_table(input_files, {}, {}, [])

    combined_KO_tab, combined_tax_tab = combine_tabs(KO_collapsed_tabs, tax_collapsed_tabs, sample_names, KO_dict, tax_dict)

    norm_KO_tab, norm_tax_tab = normalize_tabs(combined_KO_tab, combined_tax_tab, args.normalization_method)

    # writing outputs (each table paired with the suffix added to the output prefix)
    outputs = [
        (combined_KO_tab, "-gene-level-KO-function-coverages.tsv"),
        (norm_KO_tab, "-gene-level-KO-function-coverages-" + args.normalization_method + ".tsv"),
        (combined_tax_tab, "-gene-level-taxonomy-coverages.tsv"),
        (norm_tax_tab, "-gene-level-taxonomy-coverages-" + args.normalization_method + ".tsv"),
    ]

    for out_tab, suffix in outputs:
        out_tab.to_csv(args.output_prefix + suffix, index=False, sep="\t")

################################################################################


# terminal color templates: each value is a %-format string that wraps
# the substituted text in an ANSI color code followed by a reset code
tty_colors = dict(
    green='\033[0;32m%s\033[0m',
    yellow='\033[0;33m%s\033[0m',
    red='\033[0;31m%s\033[0m',
)


### functions ###
def color_text(text, color='green'):
    """
    Returns `text` wrapped in the `tty_colors` template for `color` when stdout is a
    terminal; otherwise returns `text` unchanged.
    """

    # not a terminal (e.g. piped or redirected), so leaving the text as is
    if not sys.stdout.isatty():
        return text

    return tty_colors[color] % text


def wprint(text):
    """ print wrapper: prints `text` wrapped to 80 columns with a 2-space indent on every line """

    wrapper = textwrap.TextWrapper(width=80, initial_indent="  ",
                                   subsequent_indent="  ", break_on_hyphens=False)

    print(wrapper.fill(text))


def check_all_inputs_exist(input_tables):
    """
    Checks that every path in `input_tables` exists. On the first one that doesn't,
    prints a message naming it and exits the program with status 1.
    """

    for path in input_tables:

        if os.path.exists(path):
            continue

        # reporting the first missing file and stopping
        message = "It seems the specified input file '" + str(path) + "' can't be found."
        print("")
        wprint(color_text(message, "yellow"))
        print("\nExiting for now.\n")
        sys.exit(1)


def setup_input_lists(input_tables):
    """
    Sets up input lists for file locations and sample names.

    Files with fewer than 2 lines (i.e. nothing beyond a header) are dropped. Sample names
    are the file's basename without its extension, and with the suffix
    "-gene-coverage-annotation-and-tax" removed if present.

    Returns a tuple of (input_files, sample_names), two lists in matching order.
    Errors raised while opening or reading a file are not caught here.
    """

    input_files = []
    sample_names = []

    for sample_path in input_tables:

        # checking here that the file has more than just the headers, otherwise it is empty and we are dropping it here
        with open(sample_path) as in_tab:

            # next() with a default returns None once the file runs out of lines, so only
            # "too few lines" causes a skip (a real read error is no longer silently swallowed)
            has_header = next(in_tab, None) is not None
            has_data = has_header and next(in_tab, None) is not None

        if not has_data:
            continue

        input_files.append(sample_path)
        sample = os.path.splitext(os.path.basename(sample_path))[0]

        # removing the common superfluous info that my workflow introduces to the names, if it is there
        sample = sample.replace("-gene-coverage-annotation-and-tax", "")

        sample_names.append(sample)

    return (input_files, sample_names)


def add_to_KO_dict(table, KO_dict):
    """
    Builds up the KO mapping dictionary (KO_ID -> KO_function).

    For each row of `table` whose "KO_ID" starts with "K", adds its "KO_function" to
    `KO_dict` unless that KO_ID is already there (so the first one seen is kept).
    `KO_dict` is modified in place and also returned.
    """

    # walking the two needed columns directly rather than using iterrows(), which builds
    # a full Series for every row of the table
    for KO_ID, KO_function in zip(table["KO_ID"], table["KO_function"]):

        # entries with no KO annotation (e.g. NaN) don't start with "K" and are skipped
        if str(KO_ID).startswith("K") and KO_ID not in KO_dict:

            KO_dict[KO_ID] = KO_function

    return(KO_dict)


def add_to_tax_dict(table, tax_dict, na_taxids):
    """
    Builds up the tax mapping dictionary (taxid -> list of the 7 standard-rank values).

    For each row of `table` with a non-missing "taxid" that is not yet in `tax_dict`, stores
    its [domain, phylum, class, order, family, genus, species] values (first one seen is kept).
    If all 7 of those are missing, the taxid is also appended to `na_taxids`.

    `tax_dict` and `na_taxids` are modified in place and returned as (tax_dict, na_taxids).
    """

    rank_cols = ["domain", "phylum", "class", "order", "family", "genus", "species"]

    # skipping if not classified, and only the first row of each taxid can add anything
    classified = table.loc[table["taxid"].notna(), ["taxid"] + rank_cols]
    classified = classified.drop_duplicates(subset = "taxid")

    for taxid, *lineage in classified.itertuples(index = False, name = None):

        if taxid in tax_dict:
            continue

        tax_dict[taxid] = lineage

        # some taxids have all NA for these ranks (like 1 and 131567), keep track so can sum with not classified
        # (testing each value for missingness explicitly, rather than counting unique values,
        # because NaN values are not guaranteed to collapse into a single set entry)
        if all(pd.isna(rank) for rank in lineage):
            na_taxids.append(taxid)

    return(tax_dict, na_taxids)


def process_each_table(input_files, KO_dict, tax_dict, na_taxids):
    """
    Reads in each table, updates the KO and tax mapping dictionaries from it, and collapses
    (sums) its coverage values based on KO annotations and based on taxids.

    Returns (KO_collapsed_tabs, tax_collapsed_tabs, KO_dict, tax_dict), where the first two
    are lists holding one summed table per input file, in the order of `input_files`.
    No normalization happens here.
    """

    KO_collapsed_tabs = []
    tax_collapsed_tabs = []

    for input_file in input_files:

        # taxid is read as a string so it is treated as a label, not a number
        tab = pd.read_csv(input_file, sep="\t", dtype = {'taxid': str}, low_memory = False)

        KO_dict = add_to_KO_dict(tab, KO_dict)

        tax_dict, na_taxids = add_to_tax_dict(tab, tax_dict, na_taxids)

        # collapsing based on KO terms (dropna = False keeps the un-annotated group)
        KO_tab = tab[['KO_ID', 'coverage']].groupby(by = ['KO_ID'], dropna = False).sum()

        KO_collapsed_tabs.append(KO_tab)

        # collapsing based on tax
            # first setting any taxids that are all NA at these standard ranks to missing, so
            # they are summed together with the not-classified ones; this is done on the taxid
            # column only, so matching values in other columns are left alone
        tax_tab = tab[['taxid', 'coverage']].copy()
        tax_tab['taxid'] = tax_tab['taxid'].mask(tax_tab['taxid'].isin(na_taxids))
        tax_tab = tax_tab.groupby(by = ['taxid'], dropna = False).sum()
        tax_collapsed_tabs.append(tax_tab)

    return(KO_collapsed_tabs, tax_collapsed_tabs, KO_dict, tax_dict)


def add_KO_functions(tab, KO_dict):
    """
    Adds a "KO_function" column to the combined table, as its second column, based on
    KO_ID and the KO_dict object holding mappings. KO_IDs that are not in KO_dict
    get "Not annotated". The table is modified in place and returned.
    """

    # looking up each KO_ID, with the stored function always written out as a string
    KO_functions = [str(KO_dict.get(KO, "Not annotated")) for KO in tab.KO_ID]

    tab.insert(1, "KO_function", KO_functions)

    return(tab)


def add_tax_info(tab, tax_dict):
    """
    Adds lineage info back to the combined table based on taxid and the tax_dict object
    holding mappings. The 7 rank columns are inserted right after the first column; any
    rank value that is not a string, and every rank of a taxid not in tax_dict, is
    written as "NA". The table is modified in place and returned.
    """

    ranks = ("domain", "phylum", "class", "order", "family", "genus", "species")

    # one list of values per rank, all filled in the row order of the table
    rank_values = {rank: [] for rank in ranks}

    for taxid in tab.taxid:

        # taxids with no stored lineage get a placeholder that turns into all "NA"
        lineage = tax_dict[taxid] if taxid in tax_dict else [None] * len(ranks)

        for position, rank in enumerate(ranks):

            value = lineage[position]
            rank_values[rank].append(value if isinstance(value, str) else "NA")

    # inserting in rank order, starting at column position 1
    for position, rank in enumerate(ranks, start=1):
        tab.insert(position, rank, rank_values[rank])

    return(tab)


def combine_tabs(KO_tab_list, tax_tab_list, sample_names, KO_dict, tax_dict):
    """
    Combines all KO tables into one and all tax tables into one.

    The tables in each list are joined side by side on their index (one column per sample,
    in the order of `sample_names`), with 0 filled in where a sample has no entry. A missing
    KO_ID becomes "Not annotated" and a missing taxid becomes "Not classified". KO functions
    and lineage info are then added back from `KO_dict` and `tax_dict`.

    Returns (KO_combined_tab, tax_combined_tab).
    """

    # combining KO tabs
    # NOTE: no drop_duplicates() here on purpose. It compares only the coverage values (not
    # the index), so it would throw away any KO term whose coverages happen to match another
    # KO term's across all samples.
    KO_combined_tab = pd.concat(KO_tab_list, axis=1).fillna(0).sort_index()
    # moving index to be column and changing that NaN to be "Not annotated"
    KO_combined_tab = KO_combined_tab.reset_index().fillna("Not annotated")

    # setting column names
    KO_combined_tab.columns = ["KO_ID"] + sample_names

    # adding KO functions
    KO_combined_tab = add_KO_functions(KO_combined_tab, KO_dict)

    # combining tax tabs (same reason as above for not dropping duplicate rows)
    tax_combined_tab = pd.concat(tax_tab_list, axis=1).fillna(0).sort_index()
    # moving index to be column
    tax_combined_tab = tax_combined_tab.reset_index()
    # setting column names
    tax_combined_tab.columns = ["taxid"] + sample_names
    # changing taxid NaN to be "Not classified"
    tax_combined_tab['taxid'] = tax_combined_tab['taxid'].fillna("Not classified")
    # adding tax full lineage info
    tax_combined_tab = add_tax_info(tax_combined_tab, tax_dict)

    return(KO_combined_tab, tax_combined_tab)


def median_ratio_norm(tab):
    """
    Performs median-ratio normalization on a table of numeric columns.

    Each column's size factor is the median, over the rows with a geometric mean above 0,
    of that column's values divided by the row geometric means. Returns the table with
    each column divided by its size factor.
    """

    ## calculating size factors
    # geometric mean of each row (silencing the divide warning numpy gives for log of 0)
    with np.errstate(divide = 'ignore'):
        row_geomeans = gmean(tab, axis = 1)

    # ratio of every value to the geometric mean of its row
    ratios = tab.div(row_geomeans, axis = 0)

    # only rows with a geometric mean above 0 contribute to the medians
    size_factors = ratios.loc[row_geomeans > 0].median().to_list()

    # dividing each column by its size factor
    return(tab.div(size_factors, axis = "columns"))


def normalize_tabs(combined_KO_tab, combined_tax_tab, normalization_method):
    """
    Generates normalized versions of the merged output tables, leaving the inputs untouched.

    The columns normalized in both tables are all columns of the KO table other than
    "KO_ID" and "KO_function". With "CPM", each of those columns is scaled to sum to
    1 million; with anything else, median-ratio normalization is applied.

    Returns (norm_KO_tab, norm_tax_tab).
    """

    sample_cols = [col for col in combined_KO_tab.columns.to_list() if col not in ("KO_ID", "KO_function")]

    # working on copies so the un-normalized tables can still be written out as they are
    norm_KO_tab, norm_tax_tab = combined_KO_tab.copy(), combined_tax_tab.copy()

    for norm_tab in (norm_KO_tab, norm_tax_tab):

        values = norm_tab[sample_cols]

        if normalization_method == "CPM":
            norm_tab[sample_cols] = values / values.sum() * 1000000
        else:
            norm_tab[sample_cols] = median_ratio_norm(values)

    return(norm_KO_tab, norm_tax_tab)


# calling main() only when this file is run as a script
if __name__ == "__main__":
    main()
